'''
Faraday Penetration Test IDE
Copyright (C) 2013  Infobyte LLC (http://www.infobytesec.com/)
See the file 'doc/LICENSE' for the license information

'''
import pytest
from faraday.server.models import (
    CommandObject,
    PolicyViolation,
    PolicyViolationVulnerabilityAssociation,
    Reference,
    ReferenceVulnerabilityAssociation,
)


def test_standard_host_vuln_hostnames(vulnerability_factory,
                                      host_with_hostnames,
                                      session):
    """A host-level vulnerability (no service) reports its host's hostnames."""
    host = host_with_hostnames
    vuln = vulnerability_factory.create(
        host=host,
        service=None,
        workspace=host.workspace,
    )
    session.commit()

    actual = set(vuln.hostnames)
    assert actual == set(host.hostnames)


def test_standard_vuln_service_hostnames(vulnerability_factory,
                                         service_factory,
                                         host_with_hostnames,
                                         session):
    """A service-level vulnerability reports the hostnames of the service's host."""
    host = host_with_hostnames
    workspace = host.workspace
    service = service_factory.create(host=host, workspace=workspace)
    vuln = vulnerability_factory.create(
        service=service,
        host=None,
        workspace=workspace,
    )
    session.commit()

    actual = set(vuln.hostnames)
    assert actual == set(host.hostnames)


def test_web_vuln_hostnames(vulnerability_web_factory,
                            service_factory,
                            host_with_hostnames,
                            session):
    """A web vulnerability reports the hostnames of its service's host."""
    host = host_with_hostnames
    workspace = host.workspace
    service = service_factory.create(host=host, workspace=workspace)
    vuln = vulnerability_web_factory.create(service=service,
                                            workspace=workspace)
    session.commit()

    actual = set(vuln.hostnames)
    assert actual == set(host.hostnames)


def test_code_vuln(vulnerability_code, session):
    """Source code vulnerabilities are not tied to any hostname."""
    session.commit()
    hostnames = vulnerability_code.hostnames
    assert hostnames == []


class TestReferences:
    """Tests for the set-like ``references`` association of vulnerabilities.

    Subclasses reuse every test for a different association by overriding
    the class attributes below, so tests must access the association only
    through ``field_name`` / ``instances_field_name`` (see ``childs``).
    """

    # Attribute exposing the children as plain names (strings).
    field_name = 'references'
    # Attribute exposing the children as model instances.
    instances_field_name = 'reference_instances'
    # Model class of the child objects.
    model = Reference
    # Association model linking a vulnerability to a child.
    intermediate_model = ReferenceVulnerabilityAssociation
    # Attribute of ``intermediate_model`` that points at the child.
    intermediate_field = 'reference'

    @pytest.fixture(autouse=True)
    def load_data(self, vulnerability_factory, session):
        """Create two vulns sharing a workspace and one in another workspace."""
        self.vuln = vulnerability_factory.create()
        self.vuln2 = vulnerability_factory.create(
            workspace=self.vuln.workspace)
        self.vuln_different_ws = vulnerability_factory.create()
        self.vulns = [self.vuln, self.vuln2, self.vuln_different_ws]
        session.commit()
        assert self.vuln.workspace_id != self.vuln_different_ws.workspace_id

    @pytest.fixture
    def child(self, session):
        """Persist a child named 'CVE-2017-1234' in ``self.vuln``'s workspace."""
        child = self.model('CVE-2017-1234', self.vuln.workspace_id)
        session.add(child)
        session.commit()
        return child

    def childs(self, vuln=None, instance=False):
        """Return the children of ``vuln`` (defaults to ``self.vuln``).

        With ``instance=True`` the model-instance collection is returned
        instead of the collection of names.
        """
        vuln = vuln or self.vuln
        field = self.instances_field_name if instance else self.field_name
        return getattr(vuln, field)

    def test_empty_references(self):
        """Fresh vulnerabilities have an empty set of children."""
        for vuln in self.vulns:
            assert isinstance(self.childs(vuln, True), set)
            # Use the configurable field so subclasses check their own
            # association instead of always checking ``references``.
            assert len(self.childs(vuln)) == 0

    def test_add_references(self, session):
        """Adding a name creates one child in the vuln's workspace."""
        self.childs().add('CVE-2017-1234')
        session.add(self.vuln)
        assert session.new
        session.commit()
        assert self.model.query.count() == 1
        ref = self.model.query.first()
        assert ref.name == 'CVE-2017-1234'
        assert ref.workspace_id == self.vuln.workspace_id

        # Re-adding the same name must not create another child.
        self.childs().add('CVE-2017-1234')
        assert self.model.query.count() == 1
        assert len(self.childs()) == 1

    def test_add_existing_child(self, session, child):
        """Vulns in the same workspace reuse an already existing child."""
        for vuln in [self.vuln, self.vuln2]:
            self.childs(vuln).add('CVE-2017-1234')
            session.commit()
            assert self.model.query.count() == 1
            assert len(self.childs(vuln)) == 1
            assert self.childs(vuln, True).pop().id == child.id

    @pytest.mark.parametrize('orphan_vuln', [True, False],
                             ids=['with_orphan_child',
                                  'with_used_child'])
    def test_add_existing_from_other_workspace(self, session, child,
                                               orphan_vuln):
        """A name existing only in another workspace yields a new child.

        Checked both when the existing child is unattached and when it is
        already attached to ``self.vuln``.
        """
        if not orphan_vuln:
            self.childs().add(child.name)
            session.commit()
        self.childs(self.vuln_different_ws).add(child.name)
        session.add(self.vuln_different_ws)
        session.commit()
        assert self.model.query.count() == 2
        assert len(self.childs(self.vuln_different_ws)) == 1
        new_child = self.childs(self.vuln_different_ws, True).pop()
        assert (new_child.workspace_id
                == self.vuln_different_ws.workspace_id)
        assert new_child.id != child.id

    def test_remove_reference(self, session, child):
        """Removing a name deletes the vuln/child association row."""
        session.add(self.vuln)
        self.childs().add(child.name)
        session.commit()
        self.childs().remove(child.name)
        session.commit()
        filters = {
            'vulnerability': self.vuln,
            self.intermediate_field: child
        }
        assert self.intermediate_model.query.filter_by(**filters).count() == 0

    @pytest.mark.skip('not implemented yet')
    def test_removes_orphan(self):
        pass

    @pytest.mark.parametrize('previous_childs,new_childs', [
        (set(), set()),
        (set(), {'CVE-2017-1234'}),
        ({'CVE-2017-1234'}, set()),
        ({'CVE-2017-1234'}, {'CVE-2017-1234'}),
        ({'CVE-2017-1234'}, {'CVE-2017-4321'}),
        ({'CVE-2017-1234'}, {'CVE-2017-1234', 'CVE-2017-4321'}),
    ], ids=[
        '{} -> {}',
        '{} -> {a}',
        '{a} -> {}',
        '{a} -> {a}',
        '{a} -> {b}',
        '{a} -> {a, b}',
    ])
    def test_direct_assignation(self, session, previous_childs, new_childs):
        """Assigning a set of names replaces the previous children."""
        for ref in previous_childs:
            self.childs().add(ref)
        session.add(self.vuln)
        session.commit()
        setattr(self.vuln, self.field_name, new_childs)
        session.commit()
        session.refresh(self.vuln)
        assert self.childs() == new_childs

    def test_create_workspace_and_vuln_with_childs(
            self, session, vulnerability_factory, child):
        # This should not raise an error since the workspace will be propagated
        # to the childs as created vuln has a workspace and is persisted
        # NOTE(review): no flush/commit happens after the assignment, so an
        # error raised only at flush time would go unnoticed — consider
        # adding session.commit() once confirmed it is expected to pass.
        vuln = vulnerability_factory.create()
        setattr(vuln, self.field_name, {'CVE-2017-1234'})

    def test_create_vuln_with_childs(self, session, vulnerability_factory):
        """Children can be set on an unsaved vuln whose workspace is persisted."""
        vuln = vulnerability_factory.build()
        session.add(vuln.workspace)
        session.commit()
        assert vuln.workspace.id
        setattr(vuln, self.field_name, {'CVE-2017-1234'})
        session.add(vuln)
        session.commit()
        assert len(self.childs(vuln)) == 1
        assert self.model.query.count() == 1


class TestPolicyViolations(TestReferences):
    """Run every ``TestReferences`` test against the policy violations field."""

    model = PolicyViolation
    intermediate_model = PolicyViolationVulnerabilityAssociation
    intermediate_field = 'policy_violation'
    field_name = 'policy_violations'
    instances_field_name = 'policy_violation_instances'


class TestCommandProperties:
    """Tests for the ``creator_command_id`` / ``creator_command_tool`` properties."""

    @staticmethod
    def _add_non_creator_commands(vulnerability, workspace,
                                  empty_command_factory, session, count=5):
        """Link ``count`` fresh commands to ``vulnerability`` as non-creators.

        Each link is a ``CommandObject`` built with ``created_persistent=False``.
        The calling tests assert that these links do not change the creator
        command reported by the vulnerability.
        """
        for _ in range(count):
            new_command = empty_command_factory.create(workspace=workspace)
            # Presumably flushed so the new command is persisted (and has an
            # id) before the CommandObject references it.
            session.flush()
            session.add(CommandObject(vulnerability,
                                      command=new_command,
                                      workspace=workspace,
                                      created_persistent=False))

    def test_no_creator_command(self, vulnerability, session):
        """A vulnerability without linked commands has no creator command."""
        session.commit()
        assert vulnerability.creator_command_id is None
        assert vulnerability.creator_command_tool is None

    def test_creator_command(self, workspace, vulnerability,
                             empty_command_factory, session):
        """The command linked via ``CommandObject.create`` is the creator."""
        command = empty_command_factory.create(workspace=workspace)
        session.add(vulnerability)
        session.flush()
        assert vulnerability.id is not None
        assert command.id is not None
        assert vulnerability.workspace is workspace
        assert command.workspace is workspace

        CommandObject.create(
            vulnerability,
            command
        )
        self._add_non_creator_commands(vulnerability, workspace,
                                       empty_command_factory, session)
        session.commit()
        assert vulnerability.creator_command_id == command.id
        assert vulnerability.creator_command_tool == command.tool

    def test_different_object_type(self, vulnerability, workspace,
                                   empty_command_factory, session):
        """A CommandObject with a mismatching object type is not the creator."""
        host_command = empty_command_factory.create(workspace=workspace)
        session.add(vulnerability)
        session.flush()
        # Link a command but relabel it as pointing at a host; it must not
        # be reported as this vulnerability's creator.
        invalid_co = CommandObject.create(vulnerability, host_command)
        invalid_co.object_type = 'host'
        session.add(invalid_co)
        session.commit()

        command = empty_command_factory.create(workspace=workspace)
        session.flush()

        CommandObject.create(
            vulnerability,
            command
        )
        self._add_non_creator_commands(vulnerability, workspace,
                                       empty_command_factory, session)
        session.commit()
        assert vulnerability.creator_command_id == command.id
        assert vulnerability.creator_command_tool == command.tool
